feat: add arg to enable dft in liger #3125
📝 Walkthrough

Introduces an optional configuration `liger_use_token_scaling` for Liger’s fused linear cross-entropy (FLCE). Updates README usage. Adds the field to `LigerArgs`. In the plugin's `pre_model_load`, when both FLCE and token scaling are enabled, runtime patches force `use_token_scaling=True` for both the FLCE function and the loss class initializer.
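As a rough illustration of the runtime patching the walkthrough describes, the sketch below forces a keyword argument on a function and on a class initializer. `fused_loss`, `FusedLoss`, and `force_kwarg` are hypothetical stand-ins invented for this example; they are not Liger-Kernel or Axolotl APIs, and the actual plugin code may differ.

```python
import functools


def fused_loss(logits, target, use_token_scaling=False):
    """Hypothetical stand-in for an FLCE function; returns the flag it received."""
    return use_token_scaling


class FusedLoss:
    """Hypothetical stand-in for an FLCE loss class."""

    def __init__(self, use_token_scaling=False):
        self.use_token_scaling = use_token_scaling


def force_kwarg(func, name, value):
    """Return a wrapper that always calls func with keyword `name` set to `value`."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Override whatever the caller passed by keyword. A caller passing the
        # argument positionally would trigger a TypeError (duplicate argument).
        kwargs[name] = value
        return func(*args, **kwargs)

    return wrapper


# Patch the function and the class initializer in place.
fused_loss = force_kwarg(fused_loss, "use_token_scaling", True)
FusedLoss.__init__ = force_kwarg(FusedLoss.__init__, "use_token_scaling", True)
```

After patching, both `fused_loss(...)` and `FusedLoss(...)` see `use_token_scaling=True` regardless of what the caller passes by keyword.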
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
📖 Documentation Preview: https://68b68ca04d39d4dd3d02e0fc--resonant-treacle-0fd729.netlify.app
Deployed on Netlify from commit 0b2795f
Actionable comments posted: 0
🧹 Nitpick comments (2)
src/axolotl/integrations/liger/README.md (1)

21-23: Clarify version/feature dependency for token scaling.

Readers need to know this only works with FLCE and a Liger-Kernel build that includes the token-scaling feature (PR #860+).

```diff
 # FLCE-specific
-liger_use_token_scaling: true
+# Requires Liger-Kernel with token-scaling support (PR #860+) and FLCE enabled
+liger_use_token_scaling: true
```

src/axolotl/integrations/liger/args.py (1)

38-46: Guard misuse and clarify description.

Warn when token scaling is set without FLCE, and note the dependency in the field description.

```diff
 liger_use_token_scaling: bool | None = Field(
     default=None,
     json_schema_extra={
         "description": (
-            "Enables use_token_scaling in fused_linear_cross_entropy. "
-            "When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
+            "Enables use_token_scaling in fused_linear_cross_entropy (FLCE). "
+            "When True, each token's loss is multiplied by its predicted probability (detached from gradients). "
+            "Requires `liger_fused_linear_cross_entropy: true` and a Liger-Kernel build with token-scaling support."
         )
     },
 )
```

Add a validator (outside this hunk) to warn when ineffective:

```python
# place near other @model_validator(mode="before")
@model_validator(mode="before")
@classmethod
def check_token_scaling_requires_flce(cls, data):
    if data.get("liger_use_token_scaling") and not data.get("liger_fused_linear_cross_entropy"):
        LOG.warning(
            "`liger_use_token_scaling: true` has no effect unless `liger_fused_linear_cross_entropy: true` is also set."
        )
    return data
```
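To make the field description concrete ("each token's loss is multiplied by its predicted probability"), here is a small numeric sketch in plain Python. It assumes "predicted probability" means the probability assigned to the target token and treats that factor as a constant, mirroring "detached from gradients". `token_scaled_losses` is a hypothetical helper for illustration only, not the Liger-Kernel implementation.

```python
import math


def token_scaled_losses(target_probs):
    """Per-token cross-entropy, each scaled by the token's own predicted probability.

    target_probs: probability the model assigned to the correct token at each
    position (assumed interpretation of "predicted probability").
    """
    losses = []
    for p in target_probs:
        ce = -math.log(p)      # standard cross-entropy for this token
        losses.append(p * ce)  # scale by the probability, treated as a constant
    return losses


probs = (0.9, 0.5, 0.1)
plain = [-math.log(p) for p in probs]
scaled = token_scaled_losses(probs)
```

Under this reading, low-confidence tokens are down-weighted: at p = 0.1 the plain loss is about 2.30, while the scaled loss is about 0.23.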
📒 Files selected for processing (3)
- src/axolotl/integrations/liger/README.md (1 hunk)
- src/axolotl/integrations/liger/args.py (2 hunks)
- src/axolotl/integrations/liger/plugin.py (1 hunk)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/integrations/liger/args.py (1)

38-46: LGTM: adds well-scoped opt-in flag with clear schema.

src/axolotl/integrations/liger/plugin.py (1)

51-76: Add missing `inspect` import and verify patch logic

- In `src/axolotl/integrations/liger/plugin.py`, add `import inspect` before using `inspect.signature`.
- Ensure the guard skips the monkey-patch when `use_token_scaling` isn’t in the function or `__init__` signature; manually test in an environment with Liger-Kernel installed to confirm no TypeError is raised.
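The signature guard suggested above could look roughly like the following. `accepts_kwarg`, `old_flce`, `new_flce`, and `OldLoss` are hypothetical names made up for this sketch; only `inspect.signature` is a real standard-library call, and the plugin's actual guard may be written differently.

```python
import inspect


def accepts_kwarg(func, name):
    """Return True if `name` appears as a parameter in func's signature."""
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Signature cannot be introspected; skip patching to be safe.
        return False
    return name in params


def old_flce(logits, target):
    """Stand-in for a release without token-scaling support."""
    return None


def new_flce(logits, target, use_token_scaling=False):
    """Stand-in for a release with token-scaling support."""
    return use_token_scaling


class OldLoss:
    """Stand-in loss class whose __init__ lacks the flag."""

    def __init__(self, ignore_index=-100):
        self.ignore_index = ignore_index


# Only patch callables that actually accept the keyword, so an older
# install never receives an unexpected argument.
patchable = [
    f for f in (old_flce, new_flce, OldLoss.__init__)
    if accepts_kwarg(f, "use_token_scaling")
]
```

With this check in place, a build lacking `use_token_scaling` is left unpatched instead of failing with a `TypeError` at call time.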
This should be a draft, right? Since it needs a new Liger release.
Description
Adds support for linkedin/Liger-Kernel#860
Enable via
`liger_use_token_scaling: true` (>0.6.2)
Motivation and Context
How has this been tested?
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit
New Features
Documentation